import paddle

class ConvUnit(paddle.nn.Sequential):
    def __init__(self, in_dim, ou_dim, kernel_size=1, stride=1):
        """
        初始化卷积单元，带有激活函数
        params:
        - in_dim     (int): 输入维度
        - ou_dim     (int): 输出维度
        - kernel_size(int): 卷积大小
        - stride     (int): 滑动步长
        """
        super().__init__(
            paddle.nn.Conv2D(      # 卷积函数
                in_dim, ou_dim, kernel_size, stride, padding='same',
                weight_attr=paddle.ParamAttr(initializer=paddle.nn.initializer.KaimingNormal()),
                bias_attr=False
            ),
            paddle.nn.BatchNorm2D( # 批归一化
                ou_dim,
                weight_attr=paddle.ParamAttr(initializer=paddle.nn.initializer.KaimingNormal()),
                bias_attr=paddle.ParamAttr(initializer=paddle.nn.initializer.KaimingNormal())
            ),
            paddle.nn.LeakyReLU()  # 激活函数
        )
        
class ConvNorm(paddle.nn.Sequential):
    def __init__(self, in_dim, ou_dim, kernel_size=1, stride=1):
        """
        初始化卷积单元，不带激活函数
        params:
        - in_dim     (int): 输入维度
        - ou_dim     (int): 输出维度
        - kernel_size(int): 卷积大小
        - stride     (int): 滑动步长
        """
        super().__init__(
            paddle.nn.Conv2D(      # 卷积函数
                in_dim, ou_dim, kernel_size, stride, padding='same',
                weight_attr=paddle.ParamAttr(initializer=paddle.nn.initializer.KaimingNormal()),
                bias_attr=False
            ),
            paddle.nn.BatchNorm2D( # 批归一化
                ou_dim,
                weight_attr=paddle.ParamAttr(initializer=paddle.nn.initializer.KaimingNormal()),
                bias_attr=paddle.ParamAttr(initializer=paddle.nn.initializer.KaimingNormal())
            )
        )

class ProjUint(paddle.nn.Sequential):
    def __init__(self, in_dim, ou_dim, stride=1):
        """
        初始化投影单元，改变特征大小和维度
        params:
        - in_dim(int): 输入维度
        - ou_dim(int): 输出维度
        - stride(int): 滑动步长
        """
        super().__init__(
            paddle.nn.AvgPool2D(   # 均匀池化
                kernel_size=stride, stride=stride, padding=0
            ),
            paddle.nn.Conv2D(      # 卷积函数
                in_dim, ou_dim, kernel_size=1, stride=1, padding=0,
                weight_attr=paddle.ParamAttr(initializer=paddle.nn.initializer.KaimingNormal()),
                bias_attr=False
            ),
            paddle.nn.BatchNorm2D( # 批归一化
                ou_dim,
                weight_attr=paddle.ParamAttr(initializer=paddle.nn.initializer.KaimingNormal()),
                bias_attr=paddle.ParamAttr(initializer=paddle.nn.initializer.KaimingNormal())
            )
        )

class SSRSplit(paddle.nn.Layer):
    def __init__(self, in_dim, ou_dim, stride=1, splits=1):
        """
        初始化分割单元
        params:
        - in_dim (int): 输入维度
        - ou_dim (int): 输出维度
        - stride (int): 滑动步长
        - splits (int): 分割次数，分割尺度为2^n
        """
        super().__init__()
        # 输入参数检查
        assert stride in [1, 2], '错误：滑动步长必须为1或2！'
        assert splits >= 0, '错误：分割次数必须大于等于0！'
        
        # 设置分割变量
        self.splits = splits # 分割次数
        # self.dimensions=[] # 维度列表
        self.split_list = [] # 分割列表
        
        # 添加分割列表
        split_item = self.add_sublayer( # 添加第一个分割项目
            'split_' + str(0),
            ConvUnit(in_dim, ou_dim, kernel_size=3, stride=stride)
        )
        self.split_list.append(split_item)
        
        for i in range(self.splits):    # 添加剩余分的割项目
            # 添加维度列表
            if i < self.splits:
                ou_dim //= 2
                # self.dimensions.append(ou_dim)
            
            # 添加分割列表
            split_item = self.add_sublayer(
                'split_' + str(i+1),
                ConvUnit(ou_dim, ou_dim, kernel_size=3, stride=1)
            )
            self.split_list.append(split_item)
        
    def forward(self, x):
        # 提取特征
        x_list = [] # 特征列表
        for i, split_item in enumerate(self.split_list):
            if i < self.splits:
                x = split_item(x)
                # x_item, x = paddle.split(x, num_or_sections=[-1, self.dimensions[i]], axis=1)
                x_item, x = paddle.split(x, num_or_sections=2, axis=1)# 使用该接口时，通道维度为2^n
                x_list.append(x_item)
            else:
                x = split_item(x)
                x_list.append(x)

        # 合并特征
        x = paddle.concat(x_list, axis=1)
        
        return x

class UpSample(paddle.nn.Sequential):
    def __init__(self, in_dim, ou_dim, expand=1):
        """
        初始化放大单元，改变特征大小和维度
        params:
        - in_dim(int): 输入维度
        - ou_dim(int): 输出维度
        - expand(int): 放大倍数
        """
        super().__init__(
            paddle.nn.Upsample(    # 线性插值
                scale_factor=expand, mode='bilinear'
            ),
            paddle.nn.Conv2D(      # 卷积函数
                in_dim, ou_dim, kernel_size=1, stride=1, padding=0,
                weight_attr=paddle.ParamAttr(initializer=paddle.nn.initializer.KaimingNormal()),
                bias_attr=False
            ),
            paddle.nn.BatchNorm2D( # 批归一化
                ou_dim,
                weight_attr=paddle.ParamAttr(initializer=paddle.nn.initializer.KaimingNormal()),
                bias_attr=paddle.ParamAttr(initializer=paddle.nn.initializer.KaimingNormal())
            )
        )

class SSRMerge(paddle.nn.Layer):
    def __init__(self, in_dim, ou_dim, expand=0):
        """
        初始化合并结构
        params:
        - in_dim(int): 输入维度
        - ou_dim(int): 输出维度
        - expand(int): 放大倍数
        """
        super().__init__()
        
        # 添加上采样层
        self.upsample = UpSample(in_dim, ou_dim, expand)
        
    def forward(self, x_item, x):
        # 放大特征
        x = self.upsample(x)
        
        # 合并特征
        x = paddle.concat([x_item, x], axis=1)
        
        return x